Cluster annotation¶

In [195]:
%load_ext autoreload
%autoreload 2
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
In [196]:
from scip_workflows.common import *
In [197]:
import anndata
import scanpy
from matplotlib.patches import ConnectionStyle
from sklearn.feature_selection import mutual_info_classif
import warnings
import pickle
from scip_workflows.core import plot_gate_czi
In [216]:
try:
    adata = snakemake.input.adata
    output_three = snakemake.output[0]
    output_cd15_cd45 = snakemake.output[1]
    output_cd15_siglec8 = snakemake.output[2]
    image_root = snakemake.input.image_root
except NameError:
    image_root = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800")
    data_dir = Path("/home/maximl/scratch/data/vsc/datasets/cd7/800/scip/061020221736/")
    adata = data_dir / "adata.pickle"
    output_three = data_dir / "figures" / "cluster_panels.png"
    output_cd15_cd45 = data_dir / "figures" / "cd15_vs_cd45_facets.png"
    output_cd15_siglec8 = data_dir / "figures" / "cd15_vs_siglec8_facets.png"
    output_unclassified = data_dir / "figures" / "unclassified_cluster.png"
In [217]:
def map_names(a):
    return {
        "feat_combined_sum_DAPI": "DAPI",
        "feat_combined_sum_EGFP": "CD45",
        "feat_combined_sum_RPe": "Siglec 8",
        "feat_combined_sum_APC": "CD15"
    }[a]
In [218]:
with open(adata, "rb") as fh:
    adata = pickle.load(fh)
In [219]:
adata.obs.meta_path = adata.obs.meta_path.apply(lambda p: image_root.joinpath(*Path(p).parts[Path(p).parts.index("800")+1:]))
In [220]:
markers = [col for col in adata.var_names if col.startswith(tuple("feat_combined_sum_%s" % m for m in ("EGFP", "RPe", "APC", "DAPI")))]
In [221]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = scanpy.pl.matrixplot(adata, markers, groupby='leiden', dendrogram=True, vmin=-2, vmax=2, cmap='RdBu_r', ax=axes[0], show=False, use_raw=False)
ax["mainplot_ax"].set_xticklabels(map(lambda a: map_names(a.get_text()), ax["mainplot_ax"].get_xticklabels()))
scanpy.pl.umap(adata, color="leiden", legend_loc='on data', ax=axes[1], show=False)
seaborn.countplot(data=adata.obs, x="leiden", hue="meta_replicate", ax=axes[2])
WARNING: dendrogram data not found (using key=dendrogram_leiden). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
Out[221]:
<AxesSubplot:xlabel='leiden', ylabel='count'>
In [222]:
adata.obs["leiden_merged"] = adata.obs.leiden.map(lambda a: a if a in [str(i) for i in [2,4,6,8]] else '1')
In [223]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
ax = scanpy.pl.matrixplot(adata, markers, groupby='leiden_merged', dendrogram=True, vmin=-2, vmax=2, cmap='RdBu_r', ax=axes[1], show=False, use_raw=False)
ax["mainplot_ax"].set_xticklabels(map(lambda a: map_names(a.get_text()), ax["mainplot_ax"].get_xticklabels()))
scanpy.pl.umap(adata, color="leiden_merged", ax=axes[2], show=False)
seaborn.countplot(data=adata.obs, x="leiden_merged", hue="meta_replicate", ax=axes[0])
WARNING: dendrogram data not found (using key=dendrogram_leiden_merged). Running `sc.tl.dendrogram` with default parameters. For fine tuning it is recommended to run `sc.tl.dendrogram` independently.
Out[223]:
<AxesSubplot:xlabel='leiden_merged', ylabel='count'>
In [224]:
scanpy.pl.scatter(adata, x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", color="leiden_merged", legend_loc="on data")
In [234]:
grid = seaborn.FacetGrid(
    data=scanpy.get.obs_df(adata, keys=["feat_combined_sum_EGFP", "feat_combined_sum_APC", "leiden_merged"], use_raw=True),
    col="leiden_merged"
)
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    seaborn.scatterplot(data=scanpy.get.obs_df(adata, keys=["feat_combined_sum_EGFP", "feat_combined_sum_APC"], use_raw=True), x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", color="grey", s=.5, alpha=.5, ax=ax)
grid.map_dataframe(seaborn.scatterplot, x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", s=1.5)
for ax in grid.axes.ravel():
    ax.set_yticks([]); ax.set_xticks([])
    ax.set_xlabel("CD45")
    ax.set_ylabel("CD15")
    
plt.savefig(output_cd15_cd45, bbox_inches='tight', pad_inches=0, dpi=200)
In [226]:
scanpy.pl.scatter(adata[adata.obs.leiden.isin(['1', '6', '8', '9'])], x="feat_combined_sum_RPe", y="feat_combined_sum_APC", color="leiden", legend_loc="on data")
In [233]:
grid = seaborn.FacetGrid(
    data=scanpy.get.obs_df(adata[adata.obs.leiden.isin(['1', '6', '8', '9'])], keys=["feat_combined_sum_RPe", "feat_combined_sum_APC", "leiden_merged"], use_raw=True),
    col="leiden_merged"
)
grid.set_titles(col_template="Cluster {col_name}")
for ax in grid.axes.ravel():
    seaborn.scatterplot(data=scanpy.get.obs_df(adata[adata.obs.leiden.isin(['1', '6', '8', '9'])], keys=["feat_combined_sum_RPe", "feat_combined_sum_APC"], use_raw=True), x="feat_combined_sum_RPe", y="feat_combined_sum_APC", color="grey", s=.5, alpha=.5, ax=ax)
grid.map_dataframe(seaborn.scatterplot, x="feat_combined_sum_RPe", y="feat_combined_sum_APC", s=1.5)
for ax in grid.axes.ravel():
    ax.set_yticks([]); ax.set_xticks([])
    ax.set_xlabel("Siglec 8")
    ax.set_ylabel("CD15")
    
plt.savefig(output_cd15_siglec8, bbox_inches='tight', pad_inches=0, dpi=200)

SHAP¶

In [136]:
import shap
shap.initjs()

from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score
from sklearn.ensemble import RandomForestClassifier
In [137]:
X_train, X_test, y_train, y_test = train_test_split(adata[:, adata.var.selected_corr], adata.obs["leiden_merged"], test_size=0.1, stratify=adata.obs["leiden_merged"])
In [138]:
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_train.to_df(), y_train.values)
In [139]:
preds = model.predict(X_test.to_df())
balanced_accuracy_score(y_test.values, preds)
Out[139]:
0.8489833382280905
In [140]:
explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test.to_df())
In [141]:
y_train.cat.categories
Out[141]:
Index(['1', '2', '4', '6', '8', '9'], dtype='object')
In [150]:
shap.plots.beeswarm(shap_values[..., 3])
In [151]:
adata.obs["meta_masks"] = adata.obs[["meta_scene", "meta_tile"]].apply(lambda r: str(data_dir / "masks" / "%s_%s.npy") % (r.meta_scene, r.meta_tile), axis=1)
In [152]:
plot_gate_czi(
    sel=adata.obs["leiden"] == '6',
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6], maxn=50,
    masks_path_col="meta_masks"
)
0 P1-D5 0 P2-D5 0 P3-D3 0 P3-D4 0 P3-D5 0 P4-D2 0 P4-D3 0 P5-D2 0 P5-D4 0 P7-D3 0 P8-D3 0 P8-D5 0 P9-D2 0 P9-D4 0 P10-D2 0 P10-D5 0 P12-D4 0 P13-D3 0 P14-D5 0 P15-D1 0 P15-D4 0 P15-D5 0 P16-D4 0 P17-D1 0 P17-D2 0 P18-D2 0 P19-D1 0 P19-D3 0 P19-D5 0 P20-D1 0 P20-D2 0 P20-D3 0 P22-D1 0 P22-D2 0 P23-D4 0 P24-D3 0 P24-D5 
In [236]:
plot_gate_czi(
    sel=adata.obs["leiden"] == '6',
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6], maxn=50
)
plt.savefig(output_unclassified)
0 P1-D5 0 P3-D3 0 P3-D4 0 P4-D1 0 P4-D2 0 P4-D5 0 P5-D2 0 P6-D4 0 P6-D5 0 P7-D1 0 P8-D4 0 P9-D1 0 P9-D4 0 P9-D5 0 P10-D2 0 P10-D3 0 P10-D4 0 P10-D5 0 P11-D4 0 P14-D1 0 P14-D2 0 P14-D3 0 P14-D5 0 P15-D3 0 P15-D5 0 P17-D1 0 P18-D4 0 P18-D5 0 P19-D1 0 P19-D3 0 P20-D2 0 P21-D1 0 P21-D3 0 P22-D5 0 P23-D4 0 P24-D1 0 P24-D2 0 P24-D3 0 P24-D5 0 P25-D2 
In [157]:
quantiles = adata.to_df().filter(regex="feat_combined_sum").quantile([0.05,0.95])
extent = quantiles.loc[:, ["feat_combined_sum_%s" % s for s in ["DAPI", "EGFP", "RPe", "APC", "Bright", "Oblique", "PGC"]]].T.values
In [159]:
plot_gate_czi(
    sel=adata.obs["leiden"] == '6',
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6], maxn=50,
    extent=extent
)
0 P1-D4 0 P2-D2 0 P2-D5 0 P3-D1 0 P3-D3 0 P4-D2 0 P4-D3 0 P4-D5 0 P6-D3 0 P6-D4 0 P7-D5 0 P8-D5 0 P9-D1 0 P9-D4 0 P10-D1 0 P10-D2 0 P10-D4 0 P10-D5 0 P12-D5 0 P13-D5 0 P14-D1 0 P14-D5 0 P15-D1 0 P15-D3 0 P15-D4 0 P17-D2 0 P17-D5 0 P18-D3 0 P19-D2 0 P19-D3 0 P20-D1 0 P20-D5 0 P21-D1 0 P21-D3 0 P22-D1 0 P22-D2 0 P23-D1 0 P23-D5 0 P24-D1 0 P24-D2 0 P25-D2 
In [75]:
scanpy.pl.violin(adata, "feat_combined_sum_APC", groupby="leiden_merged")
In [76]:
shap.plots.scatter(shap_values[..., "feat_combined_sum_APC", 4])
In [78]:
shap.plots.beeswarm(shap_values[..., 5])
In [79]:
plot_gate_czi(
    sel=adata.obs["leiden"] == '9',
    df=adata.obs,
    channels=[0, 1, 2, 3, 4, 5, 6], maxn=30,
    masks_path_col="meta_masks"
)
0 P1-D3 0 P2-D2 0 P2-D3 0 P5-D2 0 P5-D3 0 P7-D2 0 P8-D5 0 P10-D3 0 P13-D1 0 P14-D4 0 P17-D4 0 P24-D5 

Cluster annotation¶

In [228]:
# create a dictionary to map cluster to annotation label
cluster2annotation = {
    '1': 'granulocytes',
    '8': 'eosinophils',
    '4': 'monocytes',
    '2': 'lymphocytes',
    '6': 'unclassified'
}

# add a new `.obs` column called `cell type` by mapping clusters to annotation using pandas `map` function
cat_type = pandas.CategoricalDtype(['monocytes', 'lymphocytes', 'granulocytes', 'eosinophils', 'unclassified'], ordered=True)
adata.obs['cell type'] = adata.obs['leiden_merged'].map(cluster2annotation).astype(cat_type)
In [192]:
from matplotlib.gridspec import GridSpec

fig = plt.figure(dpi=200, figsize=(10, 7), constrained_layout=True)
gs = GridSpec(2, 2, figure=fig)

ax = fig.add_subplot(gs[0, 0])
scanpy.pl.scatter(adata[adata.obs.leiden_merged.isin(cluster2annotation.keys())], x="feat_combined_sum_EGFP", y="feat_combined_sum_APC", color="leiden_merged", legend_loc="on data", ax=ax, show=False)
# ax.annotate('monocytes', xy=(.8, -.6), xytext=(3, -1.5), arrowprops=dict(facecolor='black', arrowstyle="-|>"))
# ax.annotate('lymphocytes', xy=(1.2, -1.7), xytext=(2.5, -2.5), arrowprops=dict(facecolor='black', arrowstyle="-|>"))
ax.text(s="A", x=0.02, y=1, fontsize=20, weight="heavy", alpha=0.2, transform=ax.transAxes, va="top")
ax.set_xlabel("CD45")
ax.set_ylabel("CD15")
ax.set_title("")
# ax.set_aspect(0.9)
ax.set_yticks([])
ax.set_xticks([])

# ax.annotate("", xytext=(2.5, 1.8), xy=(5, 2.5), arrowprops=dict(facecolor='grey', width=7, edgecolor="none", connectionstyle=ConnectionStyle("Arc3", rad=-0.2)))

ax2 = fig.add_subplot(gs[1, 0])
scanpy.pl.scatter(adata[adata.obs["cell type"].isin(['granulocytes', 'eosinophils'])], x="feat_combined_sum_RPe", y="feat_combined_sum_APC", color="leiden_merged", legend_loc="on data", ax=ax2, show=False)
# ax2.annotate('eosinophils', xy=(2.15, -.15), xytext=(.7, -1.8), arrowprops=dict(facecolor='black', arrowstyle="-|>"))
# ax2.annotate('neutrophils', xy=(-.75, 1), xytext=(-3, 1), arrowprops=dict(facecolor='black', arrowstyle="-|>"))
ax2.text(s="B", x=0.02, y=1, fontsize=20, weight="heavy", alpha=0.2, transform=ax2.transAxes, va="top")
ax2.set_title("")
ax2.set_ylabel("CD15")
ax2.set_xlabel("Siglec 8")
# ax2.set_aspect(0.9)
ax2.set_yticks([])
ax2.set_xticks([])

ax3 = fig.add_subplot(gs[:, 1])
scanpy.pl.umap(adata, color=["cell type"], legend_loc='on data', ax=ax3, show=False, palette=plt.get_cmap("tab10")([8, 2, 4, 2]).tolist())
ax3.text(s="C", x=0.02, y=1, fontsize=20, weight="heavy", alpha=0.2, transform=ax3.transAxes, va="top")
ax3.set_title("")
ax3.set_aspect(1)

seaborn.despine(fig)

plt.savefig(output, bbox_inches='tight', pad_inches=0, dpi=200)
WARNING: Length of palette colors is smaller than the number of categories (palette length: 4, categories length: 5. Some categories will have the same color.
In [229]:
fig, axes = plt.subplots(1, 3, figsize=(15, 5), tight_layout=True)
ax = scanpy.pl.matrixplot(adata, markers, groupby='cell type', dendrogram=False, vmin=-2, vmax=2, cmap='RdBu_r', ax=axes[1], show=False, use_raw=False)
ax["mainplot_ax"].set_xticklabels(map(lambda a: map_names(a.get_text()), ax["mainplot_ax"].get_xticklabels()))
scanpy.pl.umap(adata, color="cell type", ax=axes[2], show=False, palette="tab10")
seaborn.countplot(data=adata.obs, y="cell type", hue="meta_replicate", ax=axes[0])

axes[0].set_title("Cell type counts")
axes[1].set_title("Marker intensity")
axes[2].set_title("UMAP")
axes[0].legend(title="Replicate")

plt.savefig(output_three, bbox_inches='tight', pad_inches=0, dpi=200)
In [171]:
counts = adata.obs["cell type"].value_counts().to_frame()
counts["fraction"] = counts["cell type"] / counts["cell type"].sum()
counts.columns = ["Count", "Fraction"]
print(counts.style.to_latex())
\begin{tabular}{lrr}
{} & {Count} & {Fraction} \\
neutrophils & 21673 & 0.756818 \\
lymphocytes & 4904 & 0.171247 \\
monocytes & 1737 & 0.060656 \\
eosinophils & 323 & 0.011279 \\
\end{tabular}

In [ ]: